import numpy as np
import random, argparse
import torch, transformers, datasets
import tqdm, pprint, logging
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler


# Model checkpoints accepted by the --model command-line option.
supported_models = [
            'meta-llama/Llama-2-7b-hf'
            ]
# Dataset names accepted by the --dataset command-line option.
supported_datasets = ['wikitext2']

# These flags disable the use of TensorFloat-32 tensor cores (to avoid numerical issues).
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
# Default compute device: the first CUDA GPU when one is available, otherwise the CPU.
DEV = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')


def set_seed(seed):
    """Seed the NumPy, PyTorch and Python random number generators with ``seed``."""
    # The three generators are independent; each is seeded in turn.
    for seed_fn in (np.random.seed, torch.random.manual_seed, random.seed):
        seed_fn(seed)

def map_tensors(obj, device, dtype=None):
    """Recursively map tensors to device and dtype.

    Args:
        obj: A tensor, or an arbitrarily nested structure of lists, tuples
            (including namedtuples) and dicts that may contain tensors.
        device: Target device for every tensor; ``None`` leaves tensors on
            their current device.
        dtype: Target dtype for every tensor; ``None`` keeps each tensor's
            dtype. When given, it is applied to all tensors in the structure.

    Returns:
        A structure of the same shape and container types as ``obj`` with the
        tensors converted. Lists, tuples and dicts are rebuilt; any other
        non-tensor object is returned unchanged.
    """
    if isinstance(obj, torch.Tensor):
        if device is not None:
            obj = obj.to(device=device)
        if dtype is not None:
            obj = obj.to(dtype=dtype)
        return obj

    if isinstance(obj, (list, tuple)):
        mapped = [map_tensors(item, device, dtype) for item in obj]
        if isinstance(obj, tuple) and hasattr(obj, '_fields'):
            # namedtuple constructors take one positional argument per field,
            # not a single iterable, so the mapped items must be unpacked.
            return type(obj)(*mapped)
        return type(obj)(mapped)

    if isinstance(obj, dict):
        return {k: map_tensors(v, device, dtype) for k, v in obj.items()}  # type: ignore

    return obj


def parser_gen(argv=None):
    """Build the command-line parser, parse the arguments and log them.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) the arguments are taken from ``sys.argv``.

    Returns:
        The parsed ``argparse.Namespace``. The parsed values are also logged
        at INFO level.
    """
    parser = argparse.ArgumentParser()

    # General Arguments
    parser.add_argument('--model', type=str, default='meta-llama/Llama-2-7b-hf',
                        help='Model to load;', choices=supported_models)
    parser.add_argument('--seed', type=int, default=0, help='Random Seed for HuggingFace and PyTorch')
    parser.add_argument('--dataset', type=str, default='wikitext2',
                        help='Dataset for Evaluation (default: wikitext2)', choices=supported_datasets,)
    parser.add_argument('--hf_token', type=str, default=None,
                        help='Hugging Face access token (default: None)')
    # The help text states the real default (2).
    parser.add_argument('--ppl_bsz', type=int, default=2,
                        help='Batch-size for PPL evaluation (default: 2)')
    parser.add_argument('--ppl_seq_len', type=int, default=2048,
                        help='Sequence Length for PPL evaluation (default: 2048)')

    # Rotation Arguments
    parser.add_argument('--had', action=argparse.BooleanOptionalAction, default=False,
                        help='''Wrap model with Hadamards''')

    # WandB Arguments
    parser.add_argument('--wandb', action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument('--wandb_id', type=str, default=None)
    parser.add_argument('--wandb_project', type=str, default=None)

    # parse_args(None) falls back to sys.argv[1:], so existing callers are unaffected.
    args = parser.parse_args(argv)

    logging.info('Arguments: ')
    logging.info(pprint.pformat(vars(args)))
    logging.info('--' * 30)
    return args

def skip(*args, **kwargs):
    """Accept any arguments, do nothing and return ``None``.

    Helper used to save time during initialization: it is a no-op that can
    stand in for a function whose work should be skipped.
    """
    return None

# @do_not_initialize
def get_model_and_tokenizer(
    model_name,
    hf_token=None,
    seq_len=2048
):
    """Load a Llama causal language model and its tokenizer from Hugging Face.

    Args:
        model_name: Hugging Face model identifier passed to ``from_pretrained``.
        hf_token: Optional Hugging Face access token, used for both the model
            and the tokenizer download.
        seq_len: Sequence length recorded on the returned model as
            ``model.seqlen``.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is in eval mode with the
        key/value cache disabled.
    """
    print(f"Loading {model_name} from Hugging Face")

    # Replace the weight initializers with no-ops so that module construction
    # does not spend time on random initialization before weights are loaded.
    # NOTE: this patches torch.nn.init globally and is not undone afterwards.
    torch.nn.init.kaiming_uniform_ = skip
    torch.nn.init.uniform_ = skip
    torch.nn.init.normal_ = skip

    # Pass the token here as well, matching the tokenizer call below.
    model = transformers.LlamaForCausalLM.from_pretrained(
        model_name,
        torch_dtype='auto',
        low_cpu_mem_usage=True,
        token=hf_token,
    )
    model.seqlen = seq_len
    logging.info('---> Loading {} Model with seq_len: {}'.format(model_name, model.seqlen))

    model.eval()  # This switches off dropout and other training-only behavior.
    # The attribute on the model object is kept for backward compatibility;
    # the config flag is the one expected to control the forward pass.
    model.use_cache = False
    model.config.use_cache = False

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_name,
        use_fast=True,
        token=hf_token,
    )
    return model, tokenizer


def prepare_test_dataloader(
    dataset: datasets.Dataset,
    tokenizer: transformers.PreTrainedTokenizerBase,
    seqlen: int = 2048,
    batch_size: int = 1
) -> DataLoader[dict[str, torch.Tensor]]:
    """
    Build a DataLoader over fixed-length token windows of a test dataset.

    This dataloader should be used when comparing WikiText2 perplexities with
    other papers, e.g. SparseGPT (arxiv.org/abs/2301.00774).

    Args:
        dataset: The dataset to create a dataloader from; its 'text' entries are joined.
        tokenizer: The tokenizer to use.
        seqlen: The length of each sequence produced by the loader.
        batch_size: The batch size.

    Returns:
        A DataLoader yielding dicts with "input_ids" and "attention_mask".
    """
    print("Preparing test dataloader")

    class _TokenWindows(Dataset):
        """The whole corpus tokenized once and cut into rows of `seqlen` tokens."""

        def __init__(self, ds, tokenizer, seqlen=2048):
            encoded = tokenizer("\n\n".join(ds['text']), return_tensors='pt')
            n_windows = encoded.input_ids.numel() // seqlen
            # Tokens past the last complete window are dropped.
            n_tokens = n_windows * seqlen

            self.input_ids = encoded.input_ids[0, :n_tokens].reshape(n_windows, seqlen)
            self.attn_mask = encoded.attention_mask[0, :n_tokens].reshape(n_windows, seqlen)

        def __len__(self):
            return len(self.input_ids)

        def __getitem__(self, idx):
            return {"input_ids": self.input_ids[idx], "attention_mask": self.attn_mask[idx]}

    loader = DataLoader(_TokenWindows(dataset, tokenizer, seqlen), batch_size=batch_size)
    print("Preparing test dataloader done")
    return loader


def get_dataset(name: str) -> datasets.DatasetDict:
    """
    Load one of the known datasets through the HuggingFace datasets library.

    Args:
        name: Key of the dataset to load. Must be one of "wikitext2", "ptb", "c4" or "alpaca".

    Returns:
        The loaded dataset. For "alpaca" the train split is divided 80/10/10
        into train, test and validation splits.

    Raises:
        NotImplementedError: If `name` is not one of the known keys.
    """
    print(f"Loading dataset: {name}")

    # Per-dataset arguments for `datasets.load_dataset`, plus columns to drop.
    ds_properties = {
        "wikitext2": {"path": "wikitext", "config_name": "wikitext-2-raw-v1"},
        "ptb": {"path": "ptb_text_only", "config_name": "penn_treebank"},
        "c4": {
            "path": "allenai/c4",
            "config_name": "allenai--c4",
            "data_files": {
                "train": "en/c4-train.00000-of-01024.json.gz",
                "validation": "en/c4-validation.00000-of-00008.json.gz",
            },
            "cols_to_remove": ['url', 'timestamp'],
        },
        "alpaca": {"path": "tatsu-lab/alpaca", "cols_to_remove": ['input', 'output', 'instruction']},
    }

    properties = ds_properties.get(name)
    if properties is None:
        raise NotImplementedError("The provided dataset is not supported")

    ds = datasets.load_dataset(
        properties["path"],
        name=properties.get("config_name"),
        data_files=properties.get("data_files"),
    )

    unwanted_columns = properties.get("cols_to_remove")
    if unwanted_columns is not None:
        ds = ds.remove_columns(unwanted_columns)

    if name == "alpaca":
        # Alpaca ships only a train split: hold out 20% of it, then halve the
        # held-out part into test and validation.
        ds = ds["train"].train_test_split(test_size=0.2, seed=42)
        held_out = ds.pop("test").train_test_split(test_size=0.5, seed=42)
        ds["test"] = held_out["train"]
        ds["validation"] = held_out["test"]

    print("Loading dataset done")
    return ds


@torch.no_grad()
def evaluate_ppl(
    model,
    pad_token_id,
    testloader
) -> float:
    """
    Evaluate the model's perplexity on the test set using batch processing.
    It is expected that model is already on the correct device; each batch is
    moved to the device of the model's first parameter.

    Args:
        model: Causal language model whose forward output exposes `.logits`.
        pad_token_id: Token id to exclude from the loss, or None when there is
            no padding token. An id of 0 is a valid pad token.
        testloader: Iterable of batches (dicts containing "input_ids") that are
            passed to the model as keyword arguments.

    Returns:
        exp of the mean, over all sequences, of each sequence's mean
        per-token negative log-likelihood.
    """
    model.eval()

    # Compare against None explicitly: a pad token id of 0 is falsy but valid.
    if pad_token_id is not None:
        loss_fn = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=pad_token_id)
    else:
        loss_fn = torch.nn.CrossEntropyLoss(reduction="none")

    # Follow the model's own device instead of assuming 'cuda:0'. With no
    # parameters to inspect, batches are left where they are.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    nlls = []

    logging.info("Evaluating perplexity...")
    for batch in tqdm.tqdm(testloader, desc="Evaluating perplexity", unit="batch"):
        logging.debug("Evaluating batch %d", len(nlls))
        batch = map_tensors(batch, device)

        logits = model(**batch).logits

        # shift outputs and labels autoregressively.
        logits = logits[:, :-1, :]
        shift_labels = batch["input_ids"][:, 1:]

        # CrossEntropyLoss demands data dimension is dimension 1.
        nll = loss_fn(logits.permute(0, 2, 1), shift_labels).float()

        # Average each sequence over its non-ignored positions only.
        mask = shift_labels != loss_fn.ignore_index
        nll_means = (nll * mask).sum(dim=1) / mask.sum(dim=1)
        nlls.append(nll_means)

    nlls_tensor = torch.cat(nlls)
    ppl = torch.exp(nlls_tensor.mean())

    return ppl.item()

def main():
    """Parse the command line, load the model and report its test perplexity.

    Loads the model and tokenizer, optionally wraps the model with Hadamards
    (`--had`), evaluates perplexity on the test split of the chosen dataset,
    prints the result and, with `--wandb`, logs it to Weights & Biases.
    """
    args = parser_gen()
    if args.wandb:
        import wandb
        wandb.init(project=args.wandb_project, entity=args.wandb_id)
        wandb.config.update(args)

    transformers.set_seed(args.seed)
    model, tokenizer = get_model_and_tokenizer(
        args.model,
        hf_token=args.hf_token,
        seq_len=args.ppl_seq_len,
    )
    if args.had:
        import hadamard
        hadamard.wrap_model(model)
    model.eval()

    dataset = get_dataset(args.dataset)
    test_dataset = dataset["test"]
    # Forward --ppl_seq_len so the evaluation windows match the requested
    # length rather than the loader's 2048 default.
    test_loader = prepare_test_dataloader(
        dataset=test_dataset,
        tokenizer=tokenizer,
        seqlen=args.ppl_seq_len,
        batch_size=args.ppl_bsz,
    )

    # DEV is cuda:0 when a GPU is available and the CPU otherwise.
    dataset_ppl = evaluate_ppl(model.to(DEV),
                               model.config.pad_token_id, test_loader)

    print(f'Loaded model perplexity: {dataset_ppl}')
    if args.wandb:
        wandb.log({"original_ppl": dataset_ppl})
    
    


# Run the evaluation only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()